import os
# Set the CUDA allocator configuration before torch is imported further down,
# so the setting is already in the environment when torch loads.
# NOTE(review): the value 21 for max_split_size_mb looks hand-tuned — confirm it is intentional.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:21"

import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig, TrainingArguments, Trainer,
    default_data_collator, EvalPrediction, PreTrainedModel, logging
)

# Restrict the transformers logger to error-level messages.
logging.set_verbosity_error()

class Args:
    """Static configuration for the reward-model training run.

    All values are class attributes; the module-level ``args`` instance
    below is what the rest of the script reads.
    """

    # --- Data ---------------------------------------------------------
    # CSV holding the preference pairs used for training and validation.
    csv_path = "/share/home/zhangshanqi/QSM/cpu/trl/data/奖励模型训练数据.csv"
    # Seed for the train/validation shuffle.
    split_seed = 42
    # Every tokenized sequence is padded or truncated to this many tokens.
    max_length = 2048

    # --- Model --------------------------------------------------------
    # Local directory of the pretrained backbone checkpoint.
    model_name = "/share/home/zhangshanqi/QSM/RLHF/model/chatglm3_6b"
    # The tokenizer is loaded from the same directory as the backbone.
    tokenizer_path = model_name

    # --- Optimisation ---------------------------------------------------
    batch_size = 4
    num_train_epochs = 15
    learning_rate = 1e-5

    # --- Output -------------------------------------------------------
    # Destination for Trainer output and the final saved model.
    save_dir = "/share/home/zhangshanqi/QSM/cpu/trl/专家04RM"


# Shared configuration object used throughout the script.
args = Args()

class RewardModel(PreTrainedModel):
    """Scalar reward model: a frozen fp16 backbone plus a small trainable head.

    The backbone is loaded from ``config.model_name_or_path`` and all of its
    parameters are frozen; only ``reward_head`` (Linear -> Tanh) is trainable.
    The reward for a sequence is computed from the final-layer hidden state of
    its last non-padding token and scaled to the range [-5, 5].
    """

    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Build the backbone and the reward head.

        Args:
            config: Backbone configuration. Must carry ``model_name_or_path``
                (checkpoint location) and ``hidden_size``. It is mutated in
                place to request dict outputs, hidden states and gradient
                checkpointing.
        """
        super().__init__(config)
        config.return_dict = True
        config.output_hidden_states = True
        config.gradient_checkpointing = True

        self.backbone = AutoModel.from_pretrained(
            config.model_name_or_path,
            config=config,
            trust_remote_code=True,
            torch_dtype=torch.float16
        )
        self.backbone.gradient_checkpointing_enable()
        # The backbone is a fixed feature extractor: nothing in it is trained.
        for p in self.backbone.parameters():
            p.requires_grad = False

        self.reward_head = nn.Sequential(
            nn.Linear(config.hidden_size, 1),
            nn.Tanh()
        )

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward the request to the backbone; extra arguments are ignored."""
        self.backbone.gradient_checkpointing_enable()

    @staticmethod
    def _last_token_index(attention_mask, seq_len, batch_size, device):
        """Return, per sequence, the index of the last attended position.

        Works for left-padded and right-padded batches alike. Rows whose mask
        is entirely zero fall back to index 0, and a missing mask means every
        position is attended, so the final position is used.
        """
        if attention_mask is None:
            return torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=device)
        # Weight each attended position by (index + 1); the row maximum is then
        # the last attended index + 1, or 0 when nothing is attended.
        positions = torch.arange(1, attention_mask.size(1) + 1, device=attention_mask.device)
        last_idx = (attention_mask.long() * positions).max(dim=1).values - 1
        return last_idx.clamp(min=0, max=seq_len - 1).to(device)

    def forward(self, input_ids, attention_mask=None):
        """Score each input sequence.

        Args:
            input_ids: Token ids, one row per sequence.
            attention_mask: Optional 0/1 mask with the same layout as
                ``input_ids``. When omitted, the final position of every
                sequence is treated as its last token.

        Returns:
            A 1-D tensor with one reward in [-5, 5] per sequence.
        """
        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            outputs = self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False
            )
            # NOTE(review): the permute presumably converts a (seq, batch, hidden)
            # backbone output to (batch, seq, hidden) — confirm for other backbones.
            last_hidden = outputs.hidden_states[-1].permute(1, 0, 2)

        batch_size, seq_len = last_hidden.size(0), last_hidden.size(1)
        device = last_hidden.device

        # Using `mask.sum() - 1` as the index is only valid with right padding.
        # With left padding (which the ChatGLM3 tokenizer is documented to use)
        # it would select a pad position, so locate the last attended position
        # explicitly instead.
        last_idx = self._last_token_index(attention_mask, seq_len, batch_size, device)
        batch_idx = torch.arange(batch_size, device=device)
        last_token = last_hidden[batch_idx, last_idx]

        # The backbone emits fp16 while the head's weights may be fp32; align the
        # dtypes so the head also works when autocast is not active.
        head_dtype = self.reward_head[0].weight.dtype
        # Tanh output in [-1, 1] scaled to the reward range [-5, 5].
        return 5.0 * self.reward_head(last_token.to(head_dtype)).squeeze(-1)


class PairwiseDataset(Dataset):
    """Dataset of (prompt, chosen, rejected) preference triples.

    Triples are either supplied directly via ``pairs`` or loaded from a CSV
    with the columns ``prompt``, ``方案1``, ``方案2`` and ``labels`` (column
    names are stripped and lower-cased before lookup). A label of 1 marks
    ``方案1`` as the preferred answer, a label of 2 marks ``方案2``; rows with
    any other or unparsable label are skipped.
    """

    def __init__(self, csv_path, tokenizer, max_length=2048, pairs=None):
        """Create the dataset.

        Args:
            csv_path: Path of the CSV file; only read when ``pairs`` is None.
            tokenizer: Callable tokenizer used by :meth:`tokenize`.
            max_length: Length every sequence is padded/truncated to.
            pairs: Optional pre-built list of (prompt, chosen, rejected)
                tuples; when given, the CSV is not touched.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        if pairs is not None:
            self.pairs = pairs
        else:
            self.pairs = self._load_pairs(csv_path)

    @staticmethod
    def _load_pairs(csv_path):
        """Read the CSV and return a list of (prompt, chosen, rejected) tuples."""
        df = pd.read_csv(csv_path, encoding="utf-8-sig")
        df.columns = [c.strip().lower() for c in df.columns]
        df = df.dropna(subset=["prompt", "方案1", "方案2", "labels"])

        pairs = []
        columns = zip(df["prompt"], df["方案1"], df["方案2"], df["labels"])
        for prompt, first, second, raw_label in columns:
            try:
                label = int(raw_label)
            except (TypeError, ValueError, OverflowError):
                # Unparsable label: skip the row, but let unrelated errors
                # (e.g. KeyboardInterrupt) propagate.
                continue
            prompt, first, second = str(prompt), str(first), str(second)
            if label == 1:
                pairs.append((prompt, first, second))
            elif label == 2:
                pairs.append((prompt, second, first))
        return pairs

    def __len__(self):
        """Return the number of preference triples."""
        return len(self.pairs)

    def tokenize(self, prompt, plan):
        """Tokenize ``prompt + plan`` into fixed-length tensors.

        Returns:
            A tuple ``(input_ids, attention_mask)`` of 1-D tensors, padded or
            truncated to ``max_length``.
        """
        enc = self.tokenizer(
            prompt + plan,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        return enc.input_ids[0], enc.attention_mask[0]

    def __getitem__(self, idx):
        """Return one example with the chosen sequence at index 0 and the rejected one at index 1."""
        p, chosen, rejected = self.pairs[idx]
        cid, cm = self.tokenize(p, chosen)
        rid, rm = self.tokenize(p, rejected)
        return {
            "input_ids": torch.stack([cid, rid]),
            "attention_mask": torch.stack([cm, rm]),
            # Placeholder label; the preference is encoded by the ordering above.
            "labels": torch.tensor(0)
        }

class RewardTrainer(Trainer):
    """Trainer for pairwise preference data.

    Each batch item holds two sequences stacked along dimension 1: the chosen
    answer at index 0 and the rejected answer at index 1. The loss is
    ``-log(sigmoid(score_chosen - score_rejected))`` averaged over the batch.
    """

    @staticmethod
    def _score_pairs(model, inputs):
        """Score chosen and rejected sequences with a single forward pass.

        Returns:
            A tuple ``(chosen_scores, rejected_scores)`` of 1-D tensors, one
            entry per pair in the batch.
        """
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        n_pairs = ids.size(0)
        # Concatenate chosen and rejected along the batch dimension so the model
        # is invoked only once per batch.
        input_ids = torch.cat([ids[:, 0], ids[:, 1]], dim=0)
        attention_mask = torch.cat([mask[:, 0], mask[:, 1]], dim=0)
        scores = model(input_ids, attention_mask=attention_mask)
        return scores[:n_pairs], scores[n_pairs:]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the pairwise ranking loss for one batch.

        Args:
            model: Model mapping ``(input_ids, attention_mask)`` to one score
                per sequence.
            inputs: Batch dict with ``input_ids`` and ``attention_mask`` whose
                dimension 1 indexes chosen (0) and rejected (1).
            return_outputs: When true, return ``(loss, None)`` instead of the
                bare loss.
            **kwargs: Accepted and ignored, for compatibility with Trainer
                versions that pass extra arguments.
        """
        cs, rs = self._score_pairs(model, inputs)
        loss = -F.logsigmoid(cs - rs).mean()
        return (loss, None) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate one batch.

        Returns:
            ``(loss, logits, labels)`` where ``logits[:, 0]`` are the chosen
            scores and ``logits[:, 1]`` the rejected scores, and ``labels`` is a
            tensor of ones. When ``prediction_loss_only`` is true, ``logits``
            and ``labels`` are ``None``.
        """
        # Mirror the base Trainer: make sure the batch is on the model's device.
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            # One forward pass yields both the loss and the per-pair scores;
            # scoring the chosen and rejected halves again would repeat the
            # expensive backbone computation.
            cs, rs = self._score_pairs(model, inputs)
            loss = -F.logsigmoid(cs - rs).mean()

        if prediction_loss_only:
            return loss, None, None

        logits = torch.stack([cs, rs], dim=1)
        labels = torch.ones(logits.size(0), device=logits.device)
        return loss, logits, labels

def compute_metrics(eval_pred: EvalPrediction):
    """Compute the fraction of pairs whose chosen score beats the rejected score.

    ``eval_pred.predictions`` is read as a 2-column array: column 0 holds the
    chosen scores and column 1 the rejected scores. Ties count as misses.
    """
    scores = torch.tensor(eval_pred.predictions)
    chosen_scores = scores[:, 0]
    rejected_scores = scores[:, 1]
    chosen_wins = torch.gt(chosen_scores, rejected_scores)
    accuracy = chosen_wins.float().mean().item()
    return {"pairwise_accuracy": accuracy}

def save_reward_model(model, tokenizer, save_directory):
    """Write the backbone, the tokenizer and the reward-head weights to disk.

    The backbone and tokenizer are stored via their own ``save_pretrained``
    methods; the head's state dict is written separately to
    ``reward_head.pt`` in the same directory, which is created if missing.
    """
    os.makedirs(save_directory, exist_ok=True)

    # Backbone first, then tokenizer, both into the same directory.
    for component in (model.backbone, tokenizer):
        component.save_pretrained(save_directory)

    # The head is not part of the backbone checkpoint, so it gets its own file.
    head_weights = model.reward_head.state_dict()
    head_file = os.path.join(save_directory, "reward_head.pt")
    torch.save(head_weights, head_file)

def main():
    """Train the reward model end to end.

    Steps: load the tokenizer, read the preference pairs, split them 80/20
    into train and validation sets with a fixed seed, build the model, train
    with ``RewardTrainer``, save the result and print the final validation
    metrics.

    Raises:
        ValueError: If fewer than two usable pairs were loaded, since the
            data cannot then be split into non-empty train and validation sets.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

    full = PairwiseDataset(args.csv_path, tokenizer, max_length=args.max_length)
    if len(full) < 2:
        # Fail early with a clear message instead of a confusing error from
        # deep inside the Trainer.
        raise ValueError(
            f"Need at least 2 usable preference pairs in {args.csv_path!r}, "
            f"found {len(full)}"
        )

    idxs = list(range(len(full)))
    # A dedicated generator yields the same permutation as seeding the global
    # RNG with this seed, without altering global random state.
    random.Random(args.split_seed).shuffle(idxs)
    cut = int(0.8 * len(idxs))
    train = [full.pairs[i] for i in idxs[:cut]]
    dev = [full.pairs[i] for i in idxs[cut:]]
    train_ds = PairwiseDataset(args.csv_path, tokenizer, args.max_length, train)
    dev_ds = PairwiseDataset(args.csv_path, tokenizer, args.max_length, dev)

    config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
    # RewardModel reads the checkpoint location from the config.
    config.model_name_or_path = args.model_name
    model = RewardModel(config)

    training_args = TrainingArguments(
        output_dir=args.save_dir,
        per_device_train_batch_size=args.batch_size,
        # Keep evaluation batches the same size as training batches rather
        # than relying on the library default.
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        eval_strategy="epoch",
        save_total_limit=2,
        logging_steps=10,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        # The custom batch layout must reach compute_loss untouched.
        remove_unused_columns=False,
        report_to="tensorboard",
        logging_dir="/share/home/zhangshanqi/QSM/chatglm4int/chatglm3-models/logs/04tblog",  # log path was moved here
        fp16=True
    )

    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics
    )

    trainer.train()
    save_reward_model(model, tokenizer, args.save_dir)
    print("✅ 最终验证集准确率：", trainer.evaluate())

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
